#!/usr/bin/env python3
"""
stability_check.py

Compute stability & normalization metrics from multiple per-seed rate files.
Writes JSON and prints a one-liner (the exact string you want to paste back).
"""
import argparse, json
from pathlib import Path
import numpy as np
import pandas as pd

# Rate columns whose spread within each ``n`` group is reported as "max Δrate".
CORE = ["rate_IN_to_CS","rate_CS_to_ON","rate_ON_to_CS","rate_CS_to_IN"]
# Row-sum columns; each value is compared against 1.0 for the normalization check.
ROWS = ["rowsum_IN","rowsum_CS","rowsum_ON"]


def _max_spread(df, columns):
    """Return the largest (max - min) of any of *columns* within any ``n`` group."""
    grouped = df.groupby("n")[columns]
    # pandas group-wise max/min skip NaN, so a single missing rate does not
    # hide the spread among the remaining rows of that group.
    spread = (grouped.max() - grouped.min()).to_numpy(dtype=float)
    usable = spread[~np.isnan(spread)]
    # Single-row groups contribute 0; all-NaN groups contribute nothing.
    return float(usable.max()) if usable.size else 0.0


def _unique_n(frame):
    """Return the sorted distinct non-NaN ``n`` values of *frame* as floats."""
    return sorted(float(x) for x in frame["n"].dropna().unique())


def main(argv=None):
    """Compute stability, normalization and coverage metrics and write them as JSON.

    Reads every CSV given via ``--rates``, concatenates them, and derives:

    * ``max_delta_rate``: largest max-minus-min of any ``CORE`` column within
      any group of rows sharing the same ``n`` (NaN entries are ignored).
    * ``max_rowsum_residual``: largest ``|value - 1|`` over the ``ROWS``
      columns. A NaN row sum propagates to this value.
    * coverage lists: distinct ``n`` values present, and, when the
      ``--coverage_from`` file exists and has an ``n`` column, the expected,
      missing and extra values relative to it.
    * ``num_seeds``: distinct values of the ``seed`` column, or the number of
      rate files when no such column exists.

    The metrics are written to ``--out`` (parent directories are created) and
    a one-line summary is printed.

    Args:
        argv: Optional argument list; ``None`` makes argparse read
            ``sys.argv[1:]``, which is what the no-argument call does.

    Raises:
        SystemExit: via argparse, on bad arguments, when a required column is
            absent from the rate files, or when they contain no rows.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--rates", nargs="+", required=True, help="paths to per-seed flip_rates_by_context.csv")
    ap.add_argument("--out", default="results/stability_metrics.json")
    ap.add_argument("--coverage_from", default="D_values.csv")
    args = ap.parse_args(argv)

    frames = [pd.read_csv(p) for p in args.rates]
    df = pd.concat(frames, ignore_index=True)

    # Fail with a readable message instead of a KeyError / numpy ValueError.
    absent = [c for c in ["n", *CORE, *ROWS] if c not in df.columns]
    if absent:
        ap.error(f"rate files lack required column(s): {', '.join(absent)}")
    if df.empty:
        ap.error("rate files contain no rows")

    # stability
    max_delta = _max_spread(df, CORE)

    # normalization
    max_resid = float(np.max(np.abs(df[ROWS].values - 1.0)))

    # coverage
    coverage = _unique_n(df)
    expected, missing, extra = None, [], []
    want_path = Path(args.coverage_from)
    if want_path.exists():
        dw = pd.read_csv(want_path)
        if "n" in dw.columns:
            expected = _unique_n(dw)
            # Build both lookup sets once rather than per comprehension element.
            have, wanted = set(coverage), set(expected)
            missing = [x for x in expected if x not in have]
            extra = [x for x in coverage if x not in wanted]

    # Without a seed column, each --rates file is taken to be one seed.
    num_seeds = int(df["seed"].nunique()) if "seed" in df.columns else len(frames)

    metrics = {
        "max_delta_rate": max_delta,
        "max_rowsum_residual": max_resid,
        "coverage_n": coverage,
        "coverage_expected": expected,
        "coverage_missing": missing,
        "coverage_extra": extra,
        "num_seeds": num_seeds,
    }
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"max Δrate across seeds = {max_delta:.4f}; max |row-sum−1| = {max_resid:.3e}; seeds = {metrics['num_seeds']}")

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
